{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n\n# Compiling ResNet with dynamic shapes using the `torch.compile` backend\n\nThis interactive script is intended as a sample of the Torch-TensorRT workflow with `torch.compile` on a ResNet model.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Imports and Model Definition\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch\nimport torch_tensorrt\nimport torchvision.models as models"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Initialize model with half precision and sample inputs\nmodel = models.resnet18(pretrained=True).to(\"cuda\").half().eval()\ninputs = [torch.randn((1, 3, 224, 224)).to(\"cuda\").half()]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Optional Input Arguments to `torch_tensorrt.compile`\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Enabled precision for TensorRT optimization\nenabled_precisions = {torch.half}\n\n\n# Workspace size for TensorRT\nworkspace_size = 20 << 30\n\n# Maximum number of TRT Engines\n# (Lower value allows more graph segmentation)\nmin_block_size = 7\n\n# Operations to Run in Torch, regardless of converter support\ntorch_executed_ops = {}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Compilation with `torch_tensorrt.compile`\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Build and compile the model with torch.compile, using Torch-TensorRT backend\noptimized_model = torch_tensorrt.compile(\n    model,\n    ir=\"torch_compile\",\n    inputs=inputs,\n    enabled_precisions=enabled_precisions,\n    workspace_size=workspace_size,\n    min_block_size=min_block_size,\n    torch_executed_ops=torch_executed_ops,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Equivalently, we could have run the above via the torch.compile frontend, as so:\n`optimized_model = torch.compile(model, backend=\"torch_tensorrt\", options={\"enabled_precisions\": enabled_precisions, ...}); optimized_model(*inputs)`\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Inference\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Does not cause recompilation (same batch size as input)\nnew_inputs = [torch.randn((1, 3, 224, 224)).to(\"cuda\").half()]\nwith torch.no_grad():\n    new_outputs = optimized_model(*new_inputs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# Does cause recompilation (new batch size)\nnew_batch_size_inputs = [torch.randn((8, 3, 224, 224)).to(\"cuda\").half()]\nwith torch.no_grad():\n    new_batch_size_outputs = optimized_model(*new_batch_size_inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Avoid recompilation by specifying dynamic shapes before Torch-TRT compilation\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# The following code illustrates the workflow using ir=torch_compile (which uses torch.compile under the hood)\ninputs_bs8 = torch.randn((8, 3, 224, 224)).to(\"cuda\").half()\n# This indicates dimension 0 of inputs_bs8 is dynamic whose range of values is [2, 16]\ntorch._dynamo.mark_dynamic(inputs_bs8, 0, min=2, max=16)\noptimized_model = torch_tensorrt.compile(\n    model,\n    ir=\"torch_compile\",\n    inputs=inputs_bs8,\n    enabled_precisions=enabled_precisions,\n    workspace_size=workspace_size,\n    min_block_size=min_block_size,\n    torch_executed_ops=torch_executed_ops,\n)\n\nwith torch.no_grad():\n    outputs_bs8 = optimized_model(inputs_bs8)\n\n# No recompilation happens for batch size = 12\ninputs_bs12 = torch.randn((12, 3, 224, 224)).to(\"cuda\").half()\nwith torch.no_grad():\n    outputs_bs12 = optimized_model(inputs_bs12)\n\n# The following code illustrates the workflow using ir=dynamo (which uses torch.export APIs under the hood)\n# dynamic shapes for any inputs are specified using torch_tensorrt.Input API\ncompile_spec = {\n    \"inputs\": [\n        torch_tensorrt.Input(\n            min_shape=(1, 3, 224, 224),\n            opt_shape=(8, 3, 224, 224),\n            max_shape=(16, 3, 224, 224),\n            dtype=torch.half,\n        )\n    ],\n    \"enabled_precisions\": enabled_precisions,\n    \"ir\": \"dynamo\",\n}\ntrt_model = torch_tensorrt.compile(model, **compile_spec)\n\n# No recompilation happens for batch size = 12\ninputs_bs12 = torch.randn((12, 3, 224, 224)).to(\"cuda\").half()\nwith torch.no_grad():\n    outputs_bs12 = trt_model(inputs_bs12)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.14"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}